home *** CD-ROM | disk | FTP | other *** search
/ NeXT Education Software Sampler 1992 Fall / NeXT Education Software Sampler 1992 Fall.iso / Programming / Classes / Neural-Network / Neuron.m < prev    next >
Encoding:
Text File  |  1992-07-29  |  4.9 KB  |  221 lines

  1. /* =======================================================
  2.     Neural Network Classes for the NeXT Computer
  3.     Written by: Ralph Zazula
  4.                     University of Arizona - Fall 1991
  5.                     zazula@pri.com (NeXT Mail)
  6. ==========================================================*/
  7. /*$Log:    Neuron.m,v $
  8. Revision 1.4  92/01/14  21:19:46  zazula
  9. Check in before starting HashTable mod
  10.  
  11. Revision 1.3  92/01/02  14:04:31  zazula
  12. Faster linked-list for connections
  13. No more Storage object
  14.  
  15. Revision 1.2  92/01/02  12:41:34  zazula
  16. Initial version - support for stochastic networks via temperature T
  17. */
  18. #import "Neuron.h"
  19. #import <appkit/nextstd.h>
  20. #import "math.h"
  21.  
  22.  
  23. //----------------------------------------------------------
  24.  
  25. @implementation Neuron
  26.  
  27. - inputs { return inputs; }
  28. - setType:(int)type { nodeType = type; return self; }
  29. - (int)getType { return nodeType; }
  30. - setTemp:(double)newT { T = newT; return self; }
  31. - (double)getTemp { return T; }
  32. - setRandom:theRandom { random = theRandom; return self; }
  33. - setSymmetric:(BOOL)sym { Symmetric = sym; return self; }
  34. - (BOOL)getSymmetric { return Symmetric; }
  35.  
  36. //----------------------------------------------------------
  37.  
  38. - (double)activation:(double)net
  39. {
  40.    double   temp;
  41.     
  42.     if(random == nil) random = [[Random alloc] init];
  43.    switch (nodeType) {
  44.    case Binary :
  45.       if(T > 0.0)
  46.          temp = ([random percent] <= 1.0/(1.0+exp(-2*net/T))) ? 1.0 : 0.0;
  47.       else
  48.          temp = (net > 0.5) ? 1.0 : 0.0;
  49.       break;
  50.    case Sigmoid : 
  51.       temp = 1.0/(1.0+exp(-net));
  52.       break;
  53.    case Sign : 
  54.       if(T > 0.0)
  55.          temp = ([random percent] <= 1.0/(1.0+exp(-2*net/T))) ? 1.0 : -1.0;
  56.       else
  57.          temp = (net > 0.0) ? 1.0 : -1.0;
  58.       break;
  59.    case Tanh :
  60.         if(T > 0.0)
  61.             temp = tanh(net/T);
  62.         else
  63.           temp = tanh(net);
  64.       break;
  65.    }
  66.    
  67.    return temp;
  68. }
  69.  
  70. //----------------------------------------------------------
  71.  
  72. - init
  73. {
  74.    [super init];
  75.    lastOutput = 0.0;
  76.    nodeType = Sigmoid;        // default node type
  77.    T = 0.0;                   // default temperature
  78.       head = tail = NULL;            // initialize the linked-list of connections
  79.     Symmetric = NO;                // default Symmetric connection status
  80.     
  81.    return self;
  82. }
  83.  
  84. //-----------------------------------------------------------
  85.  
  86. - step
  87. // update the output value based on our inputs
  88. {
  89.    int i = 0;
  90.    connection *C;
  91.    double temp=0.0;     // use temp variable to allow for feedback
  92.    
  93.     C = head;
  94.     while(C != NULL) {
  95.         temp += C->weight*[C->source lastOutput];
  96.         C = (connection *)C->next;
  97.     }
  98.     
  99.    lastOutput = [self activation:temp];
  100.    
  101.    return self;
  102. }
  103.  
  104. //-----------------------------------------------------------
  105.  
  106. - (double)lastOutput
  107. {
  108.    return lastOutput;
  109. }
  110.  
  111. //-----------------------------------------------------------
  112.  
  113. - connect:sender
  114. {
  115.     if(random == nil) random = [[Random alloc] init];
  116.    return [self connect:sender withWeight:[random percent]/10.0];
  117. }
  118.  
  119. //-----------------------------------------------------------
  120.  
  121. - connect:sender withWeight:(double)weight
  122. //
  123. // adds sender to the list of inputs
  124. // we should check to make sure sender is a Neruon
  125. // also need to check if it is already in the list
  126. //
  127. {
  128.    connection *C;
  129.    
  130.    C = (connection *)malloc(sizeof(connection));
  131.     if(head == NULL) {
  132.         head = C;
  133.     }
  134.     else {
  135.         tail->next = C;
  136.     }
  137.     tail = C;
  138.    C->source = sender;
  139.    C->weight = weight;
  140.     C->next   = NULL;
  141.       
  142.    return self;
  143. }
  144.  
  145. //-----------------------------------------------------------
  146.  
  147. - (double)getWeightFor:source
  148. {
  149.    int i=0;
  150.    connection *C;
  151.  
  152.     C = head;
  153.     while((C != NULL) && (C->source != source))
  154.         C = (connection *)C->next;
  155.         
  156.    if(C != NULL) {            // if C==NULL, source isn't an input
  157.       return C->weight;
  158.    }
  159.    else {
  160.       fprintf(stderr,"connection not found in getWeightFor:\n");
  161.       return NAN;
  162.    }
  163.     
  164. }
  165.  
  166. //-----------------------------------------------------------
  167.  
  168. - setWeightFor:source to:(double)weight
  169. {
  170.    int i=0;
  171.    connection *C;
  172.    
  173.     C = head;
  174.     while((C != NULL) && (C->source != source))
  175.         C = (connection *)C->next;
  176.         
  177.    if(C != NULL) {            // if C==NULL, source isn't an input
  178.       C->weight = weight;
  179.       return self;
  180.    }
  181.    else {
  182.       fprintf(stderr,"connection not found in setWeightFor:to:\n");
  183.       return nil;
  184.    }
  185. }
  186.  
  187. //-----------------------------------------------------------
  188.  
  189. - setOutput:(double)output
  190. {
  191.    lastOutput = output;
  192.    
  193.    return self;
  194. }
  195. //-----------------------------------------------------------
  196.  
  197. - changeWeightFor:source by:(double)delta
  198. {
  199.    int i=0;
  200.    connection *C;
  201.    
  202.     C = head;
  203.     while((C != NULL) && (C->source != source))
  204.         C = (connection *)C->next;
  205.         
  206.    if(C != NULL) {            // if C==NULL, source isn't an input
  207.       C->weight += delta;
  208.         if(Symmetric) // for symmetric connections
  209.            [source setWeightFor:self to:C->weight];
  210.       return self;
  211.    }
  212.    else {
  213.       fprintf(stderr,"connection not found in changeWeightfor:by:\n");
  214. //      printf("connection not found in changeWeightfor:by:\n");
  215.       return nil;
  216.    }
  217. }
  218.  
  219.  
  220. @end
  221.